ScatterElements ================= 返回一个新tensor,根据指定索引和更新值对input中的元素进行指定操作(替换或相加)。不支持隐式类型转换。举例:一个三维输入tensor的返回为: .. code-block:: python output[indices[i][j][k]][j][k] = updates[i][j][k] #if axis == 0, reduction == "none" output[i][indices[i][j][k]][k] += updates[i][j][k] #if axis == 1, reduction == "add" output[i][j][indices[i][j][k]] = updates[i][j][k] #if axis == 2, reduction == "none" 输入: - **input** - 输入数据的地址。 - **indices** - 指定索引。 - **updates** - 更新值,与indices中的元素一一对应。 - **param** - 算子计算所需参数的结构体。其各成员见下述。 - **core_mask** - 核掩码。 **ScatterElementsParameter定义:** .. code-block:: c :linenos: typedef struct ScatterElementsParameter { int* indices_stride_; // 对应于indices数组每一维度的步长 int* output_stride_; // 对应于output数组每一维度的步长 int input_dims_; // 输入张量的维度数 int axis_; // 指定索引所在的轴 int input_axis_size_; // 索引所在轴的元素数 int indices_total_num_; // indices数组的总元素数 int input_total_num_; // input数组的总元素数 int reduction_type_; // 规约类型,0代表none,1代表add } ScatterElementsParameter; 输出: - **output** - 输出地址。 支持平台: ``FT78NE`` ``MT7004`` .. note:: - FT78NE 支持int8, int16, int32, fp32, fp64, cplx64, cplx128 - MT7004 支持fp16, fp32, int16, int32, cplx64 - 如果 indices 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。 - 如果 indices 的值超出 input 索引上下界,则相应的 updates 不会更新到 output,也不会抛出索引错误。 **共享存储版本:** .. c:function:: void i8_scatter_elements_s(int8_t* input, int8_t* output, int* indices, int8_t* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void i16_scatter_elements_s(int16_t* input, int16_t* output, int* indices, int16_t* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void i32_scatter_elements_s(int* input, int* output, int* indices, int* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void hp_scatter_elements_s(half* input, half* output, int* indices, half* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void fp_scatter_elements_s(float* input, float* output, int* indices, float* updates, ScatterElementsParameter* param, int core_mask) .. 
c:function:: void dp_scatter_elements_s(double* input, double* output, int* indices, double* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void c64_scatter_elements_s(float* input, float* output, int* indices, float* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void c128_scatter_elements_s(double* input, double* output, int* indices, double* updates, ScatterElementsParameter* param, int core_mask) **C调用示例:** .. code-block:: c :linenos: :emphasize-lines: 40 void PackParam(ScatterElementsParameter* param, int* indices_shape, int* input_shape) { param->indices_stride_[param->input_dims_ - 1] = 1; int i; for (i = param->input_dims_ - 1; i > 0; --i) { param->indices_stride_[i - 1] = param->indices_stride_[i] * indices_shape[i]; } param->output_stride_[param->input_dims_ - 1] = 1; for (i = param->input_dims_ - 1; i > 0; --i) { param->output_stride_[i - 1] = param->output_stride_[i] * input_shape[i]; } param->indices_total_num_ = 1; for (i = 0; i < param->input_dims_; i++) { param->indices_total_num_ *= indices_shape[i]; } param->input_total_num_ = 1; for (i = 0; i < param->input_dims_; i++) { param->input_total_num_ *= input_shape[i]; } param->input_axis_size_ = input_shape[param->axis_]; } void TestScatterElementsSMC(int* input_shape, int* indices_shape, int ndim, int axis, int reduction_type, int core_mask) { int core_num = GetCoreNum(core_mask); int core_id = get_core_id(); int logic_core_id = GetLogicCoreId(core_mask, core_id); void* input_data = (void*)0x88000000; void* output_data = (void*)0x98000000; int* indices_data = (int*)0xA8000000; void* updates_data = (void*)0xB8000000; ScatterElementsParameter* param = (ScatterElementsParameter*)0xC8000000; if (logic_core_id == 0) { param->axis_ = axis; param->input_dims_ = ndim; param->indices_stride_ = (int*)0xC8020000; param->output_stride_ = (int*)0xC8040000; param->reduction_type_ = reduction_type; PackParam(param, indices_shape, input_shape); } sys_bar(0, 
core_num); // 初始化参数完成后进行同步 fp_scatter_elements_s(input_data, output_data, indices_data, updates_data, param, core_mask); } void main() { int input_shape[2] = {8, 30}; int indices_shape[2] = {3, 3}; int ndim = 2; int axis = 0; int reduction_type = 0; int core_mask = 0b1111; TestScatterElementsSMC(input_shape, indices_shape, ndim, axis, reduction_type, core_mask); } **私有存储版本:** .. c:function:: void i8_scatter_elements_p(int8_t* input, int8_t* output, int* indices, int8_t* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void i16_scatter_elements_p(int16_t* input, int16_t* output, int* indices, int16_t* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void i32_scatter_elements_p(int* input, int* output, int* indices, int* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void hp_scatter_elements_p(half* input, half* output, int* indices, half* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void fp_scatter_elements_p(float* input, float* output, int* indices, float* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void dp_scatter_elements_p(double* input, double* output, int* indices, double* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void c64_scatter_elements_p(float* input, float* output, int* indices, float* updates, ScatterElementsParameter* param, int core_mask) .. c:function:: void c128_scatter_elements_p(double* input, double* output, int* indices, double* updates, ScatterElementsParameter* param, int core_mask) **C调用示例:** .. 
code-block:: c :linenos: :emphasize-lines: 34 void PackParam(ScatterElementsParameter* param, int* indices_shape, int* input_shape) { param->indices_stride_[param->input_dims_ - 1] = 1; int i; for (i = param->input_dims_ - 1; i > 0; --i) { param->indices_stride_[i - 1] = param->indices_stride_[i] * indices_shape[i]; } param->output_stride_[param->input_dims_ - 1] = 1; for (i = param->input_dims_ - 1; i > 0; --i) { param->output_stride_[i - 1] = param->output_stride_[i] * input_shape[i]; } param->indices_total_num_ = 1; for (i = 0; i < param->input_dims_; i++) { param->indices_total_num_ *= indices_shape[i]; } param->input_total_num_ = 1; for (i = 0; i < param->input_dims_; i++) { param->input_total_num_ *= input_shape[i]; } param->input_axis_size_ = input_shape[param->axis_]; } void TestScatterElementsL2(int* input_shape, int* indices_shape, int ndim, int axis, int reduction_type, int core_mask) { void* input_data = (void*)0x10000000; // 私有存储版本地址设置在AM内 void* output_data = (void*)0x10001000; int* indices_data = (int*)0x10002000; void* updates_data = (void*)0x10003000; ScatterElementsParameter* param = (ScatterElementsParameter*)0x10004000; param->axis_ = axis; param->input_dims_ = ndim; param->indices_stride_ = (int*)0x10005000; param->output_stride_ = (int*)0x10006000; param->reduction_type_ = reduction_type; PackParam(param, indices_shape, input_shape); fp_scatter_elements_p(input_data, output_data, indices_data, updates_data, param, core_mask); } void main() { int input_shape[2] = {8, 30}; int indices_shape[2] = {3, 3}; int ndim = 2; int axis = 0; int reduction_type = 0; int core_mask = 0b0001; // 私有存储版本只能设置为一个核心启动 TestScatterElementsL2(input_shape, indices_shape, ndim, axis, reduction_type, core_mask); }